import taichi as ti
from taichi.math import clamp

from ..material import Material
from ..utils import (
    Vec3,
    Ray3,
    reflect,
    mix,
    PI,
    PI2,
    toNormalHemisphere,
    sampleCosineHemisphere,
    sqr,
)


@ti.func
def smithG_GGX(NoV: float, alphaG: float) -> float:
    """Evaluate the GGX shadowing/masking factor for one direction.

    Args:
        NoV: Cosine between the surface normal and the direction of interest.
        alphaG: Roughness parameter; clamped from below at 0.001 before use.

    Returns:
        ``1 / (NoV + sqrt(mix(NoV^2, 1, alphaG^2)))``. ``mix`` comes from
        ``..utils`` (presumably linear interpolation -- confirm there).
    """
    alpha_sq = sqr(ti.max(alphaG, 0.001))
    cos_sq = sqr(NoV)
    blended = mix(cos_sq, 1, alpha_sq)
    return 1 / (NoV + ti.sqrt(blended))


@ti.func
def GTR1(NoH: float, a: float) -> float:
    """Evaluate the GTR distribution with exponent 1 (used by the clear coat).

    Args:
        NoH: Cosine between the surface normal and the half vector.
        a: Roughness-like width parameter.

    Returns:
        ``(a^2 - 1) / (PI * log(a^2) * t)`` with ``t = mix(1, a^2, NoH^2)``
        when ``a < 1``; the constant ``1 / PI`` otherwise.
    """
    # Default covers a >= 1; a single trailing return keeps the function
    # shape the original used inside a Taichi func.
    density = 1 / PI
    if a < 1:
        alpha_sq = sqr(a)
        denom = PI * ti.log(alpha_sq) * mix(1, alpha_sq, sqr(NoH))
        density = (alpha_sq - 1) / denom
    return density


@ti.func
def GTR2(NoH: float, a: float) -> float:
    """Evaluate the GTR distribution with exponent 2 (the GGX form).

    Args:
        NoH: Cosine between the surface normal and the half vector.
        a: Width parameter; clamped from below at 0.001 before squaring.

    Returns:
        ``a^2 / (PI * t^2)`` with ``t = mix(1, a^2, NoH^2)``.
    """
    alpha_sq = sqr(ti.max(a, 0.001))
    denom = PI * sqr(mix(1, alpha_sq, sqr(NoH)))
    return alpha_sq / denom


@ti.func
def schlickFresnel(u: float) -> float:
    """Return the Schlick Fresnel weight ``clamp(1 - u, 0, 1) ** 5``.

    Args:
        u: Cosine term the weight is evaluated for.

    Returns:
        The fifth power of the clamped complement of ``u``, computed with
        two squarings and one multiply.
    """
    base = clamp(1 - u, 0, 1)
    fourth_power = sqr(sqr(base))
    return fourth_power * base


@ti.func
def sampleGTR2(alpha: float) -> Vec3:
    """Draw a random half vector from the GTR2 lobe, in local coordinates.

    Args:
        alpha: Width parameter; clamped from below at 0.001 before squaring.

    Returns:
        A vector ``(sin*cos(phi), sin*sin(phi), cos)`` whose z component is
        the sampled cosine of the polar angle.
    """
    alpha_sq = sqr(ti.max(0.001, alpha))
    # Draw order matters for reproducibility: azimuth first, then polar.
    azimuth = PI2 * ti.random()
    xi = ti.random()
    cos_polar = ti.sqrt((1 - xi) / mix(1, alpha_sq, xi))
    # The max() guards against a tiny negative value from rounding.
    sin_polar = ti.sqrt(ti.max(0, 1 - cos_polar * cos_polar))
    return Vec3(
        sin_polar * ti.cos(azimuth),
        sin_polar * ti.sin(azimuth),
        cos_polar,
    )


@ti.func
def sampleGTR1(alpha: float) -> Vec3:
    """Draw a random half vector from the GTR1 lobe, in local coordinates.

    For ``alpha < 1`` the squared cosine of the polar angle is
    ``(1 - a2**u) / (1 - a2)`` with ``a2 = alpha^2`` and ``u = 1 - random``.
    For ``alpha >= 1`` that expression is 0/0 at ``alpha == 1``, and ``GTR1``
    itself switches to the constant density ``1 / PI`` there, so the analytic
    limit ``cos^2 = u`` (cosine-weighted sampling) is used instead.

    Args:
        alpha: Width parameter of the lobe.

    Returns:
        A vector ``(sin*cos(phi), sin*sin(phi), cos)`` whose z component is
        the sampled cosine of the polar angle.
    """
    a2 = sqr(alpha)
    # Draw order is kept: azimuth first, then the polar-angle variate.
    phi_h = PI2 * ti.random()
    u = 1 - ti.random()
    # Limit of (1 - a2**u) / (1 - a2) as a2 -> 1; matches GTR1's a >= 1 branch.
    cos2_theta_h = u
    if alpha < 1:
        cos2_theta_h = (1 - a2**u) / (1 - a2)
    cos_theta_h = ti.sqrt(cos2_theta_h)
    # The max() guards against a tiny negative value from rounding.
    sin_theta_h = ti.sqrt(ti.max(0, 1 - cos_theta_h * cos_theta_h))
    return Vec3(sin_theta_h * ti.cos(phi_h), sin_theta_h * ti.sin(phi_h), cos_theta_h)


class DisneyBRDF(Material):
    """Disney principled BRDF material.

    The constructor stores nine scalar parameters in the ``attribute`` slots
    of this material's entry in ``Material.materials``:

        0 roughness, 1 metallic, 2 specular, 3 specularTint, 4 sheen,
        5 sheenTint, 6 clearCoat, 7 clearCoatGloss, 8 subsurface

    ``scatter`` and ``sample`` are Taichi functions whose ``self`` argument
    is the integer material index, not an instance.
    """

    def __init__(
        self,
        roughness=0.5,
        metallic=0.0,
        specular=0.5,
        specularTint=0.0,
        sheen=0.0,
        sheenTint=0.5,
        clearCoat=0.0,
        clearCoatGloss=1.0,
        subsurface=0.0,
    ):
        """Register the material and write its parameters to the attribute slots."""
        super().__init__()
        Material.materials[self.index].attribute[0] = roughness
        Material.materials[self.index].attribute[1] = metallic
        Material.materials[self.index].attribute[2] = specular
        Material.materials[self.index].attribute[3] = specularTint
        Material.materials[self.index].attribute[4] = sheen
        Material.materials[self.index].attribute[5] = sheenTint
        Material.materials[self.index].attribute[6] = clearCoat
        Material.materials[self.index].attribute[7] = clearCoatGloss
        Material.materials[self.index].attribute[8] = subsurface

    @ti.func
    def scatter(self: int, ray: Ray3, hitPoint: Vec3, norm: Vec3):
        """Sample an outgoing direction and return the path throughput weight.

        Args:
            self: Index of the material in ``Material.materials``.
            ray: Incoming ray; its ``origin`` is moved to ``hitPoint`` and its
                ``direct`` is replaced by the sampled direction.
            hitPoint: Surface position that was hit.
            norm: Surface normal; flipped if it points along the incoming ray.

        Returns:
            ``(color, ray)`` where ``color`` is ``brdf * NoL / pdf`` for the
            sampled direction, or zero when the direction lies below the
            surface or the pdf is not a positive number.
        """
        ray.origin = hitPoint
        # Make the normal face the side the ray arrived from.
        if norm.dot(ray.direct) > 0:
            norm = -norm
        V = -ray.direct
        ray.direct = DisneyBRDF.sample(self, ray, norm)
        color = Vec3(0)
        NoL = norm.dot(ray.direct)
        if NoL > 0:
            roughness = Material.materials[self].attribute[0]
            metallic = Material.materials[self].attribute[1]
            specular = Material.materials[self].attribute[2]
            specularTint = Material.materials[self].attribute[3]
            sheen = Material.materials[self].attribute[4]
            sheenTint = Material.materials[self].attribute[5]
            clearCoat = Material.materials[self].attribute[6]
            clearCoatGloss = Material.materials[self].attribute[7]
            subsurface = Material.materials[self].attribute[8]
            baseColor = Material.getColor(self, hitPoint)
            NoV = norm.dot(V)
            H = (ray.direct + V).normalized()
            NoH = norm.dot(H)
            LoH = abs(ray.direct.dot(H))  # without abs the pdf could turn negative

            # Tint colours derived from the base colour's luminance.
            Cd_lum = Vec3(0.3, 0.6, 0.1).dot(baseColor)
            C_tint = ti.select(Cd_lum > 0, baseColor / Cd_lum, 1)
            C_spec = specular * ti.math.mix(1, C_tint, specularTint)
            C_spec0 = ti.math.mix(0.08 * C_spec, baseColor, metallic)
            C_sheen = ti.math.mix(1, C_tint, sheenTint)

            # Diffuse and subsurface Fresnel terms.
            FL, FV = schlickFresnel(NoL), schlickFresnel(NoV)
            Fss90 = sqr(LoH) * roughness
            Fd90 = 0.5 + 2 * Fss90
            Fd = ti.math.mix(1, Fd90, FL) * ti.math.mix(1, Fd90, FV)
            Fss = ti.math.mix(1, Fss90, FL) * ti.math.mix(1, Fss90, FV)
            ss = 1.25 * (Fss * (1 / (NoL + NoV) - 0.5) + 0.5)

            FH = schlickFresnel(LoH)
            F_sheen = FH * sheen * C_sheen

            _diffuse = (ti.math.mix(Fd, ss, subsurface) * baseColor / PI + F_sheen) * (
                1 - metallic
            )

            # Primary specular lobe.
            Ds = GTR2(NoH, sqr(roughness))
            Fs = ti.math.mix(C_spec0, 1, FH)
            Gs = smithG_GGX(NoL, roughness) * smithG_GGX(NoV, roughness)
            _specular = Gs * Fs * Ds

            # Clear-coat lobe.
            Dr = GTR1(NoH, ti.math.mix(0.1, 0.001, clearCoatGloss))
            Fr = ti.math.mix(0.04, 1, FH)
            Gr = smithG_GGX(NoL, 0.25) * smithG_GGX(NoV, 0.25)
            _clearCoat = 0.25 * clearCoat * Gr * Fr * Dr

            # Mixture pdf using the same lobe weights as ``sample``.
            prob = Vec3(1 - metallic, 1, 0.25 * clearCoat)
            pdf = (
                prob * Vec3(NoL / PI, Ds * NoH / (4 * LoH), Dr * NoH / (4 * LoH))
            ).sum() / prob.sum()
            # Only divide by a strictly positive pdf; a zero or NaN pdf leaves
            # the contribution at zero instead of producing inf/NaN radiance.
            if pdf > 0:
                color = (_diffuse + _specular + _clearCoat) * NoL / pdf
        return color, ray

    @ti.func
    def sample(self: int, ray, norm: Vec3) -> Vec3:
        """Pick one lobe at random and sample an outgoing direction from it.

        The lobes are chosen with weights ``1 - metallic`` (diffuse), ``1``
        (specular) and ``0.25 * clearCoat`` (clear coat).

        Args:
            self: Index of the material in ``Material.materials``.
            ray: Incoming ray; only ``ray.direct`` is read.
            norm: Surface normal facing the incoming side.

        Returns:
            The sampled outgoing direction. It may point below the surface;
            ``scatter`` checks ``NoL > 0`` before using it.
        """
        roughness = Material.materials[self].attribute[0]
        metallic = Material.materials[self].attribute[1]
        clearCoat = Material.materials[self].attribute[6]
        clearCoatGloss = Material.materials[self].attribute[7]
        p_diffuse = 1 - metallic
        p_specular = 1.0
        p_clearCoat = 0.25 * clearCoat
        p = ti.random() * (p_diffuse + p_specular + p_clearCoat)
        L = Vec3(0)
        # Strict '<' so a zero-weight diffuse lobe (metallic == 1) is never
        # chosen when the random draw is exactly 0.
        if p < p_diffuse:
            L = toNormalHemisphere(sampleCosineHemisphere(), norm)
        # '<=' here keeps p == total in the specular lobe when clearCoat == 0.
        elif p <= p_diffuse + p_specular:
            L = reflect(
                ray.direct, toNormalHemisphere(sampleGTR2(sqr(roughness)), norm)
            )
        else:
            L = reflect(
                ray.direct,
                toNormalHemisphere(sampleGTR1(mix(0.1, 0.001, clearCoatGloss)), norm),
            )
        return L
